Find duplicate subtrees

Time: O(N); Space: O(N); medium

Given a binary tree, return all duplicate subtrees. For each kind of duplicate subtree, you only need to return the root node of any one of them. Two trees are duplicates if they have the same structure and the same node values.

Example 1:

    1
   / \
  2   3
 /   / \
4   2   4
   /
  4

Input: root = {TreeNode} [1,2,3,4,None,2,4,None,None,4]

Output: {TreeNode} [2,4], {TreeNode} [4]:

  2
 /
4

and

4

Therefore, you need to return the roots of the above trees in the form of a list.

[16]:
class TreeNode:
    """A binary tree node: a value plus optional left and right children."""

    def __init__(self, x):
        # Children start empty; callers link nodes by assigning .left/.right.
        self.val, self.left, self.right = x, None, None

Auxiliary Tools

[17]:
from graphviz import Graph

class TreeTasks(object):
    """Helpers for inspecting binary trees built from TreeNode objects."""

    def visualize_tree(self, tree):
        """Draw ``tree`` as an undirected graphviz ``Graph`` and display it.

        Graph nodes are keyed by ``str(node)``. TreeNode defines no
        ``__str__``/``__repr__``, so that is the default object repr, which
        differs per object. As a result, structurally identical subtrees still
        get separate graph nodes. Left children are labelled ``".<val>"`` and
        right children ``"<val>."``, so each child's side stays visible in the
        undirected drawing.

        :param tree: root TreeNode to draw; ``None`` produces an empty graph
            instead of raising AttributeError.
        :returns: the populated ``Graph``.
        """
        def add_nodes_edges(node, dot):
            # Add each existing child, connect it to ``node``, then recurse.
            if node.left:
                dot.node(name=str(node.left), label="." + str(node.left.val))
                dot.edge(str(node), str(node.left))
                add_nodes_edges(node.left, dot)
            if node.right:
                dot.node(name=str(node.right), label=str(node.right.val) + ".")
                dot.edge(str(node), str(node.right))
                add_nodes_edges(node.right, dot)
            return dot

        # Create Graph (not Digraph) object: edges are drawn without arrows.
        dot = Graph()
        if tree is not None:
            dot.node(name=str(tree), label=str(tree.val))
            add_nodes_edges(tree, dot)
        # NOTE(review): ``display`` is not imported here; presumably it is the
        # IPython/Jupyter builtin available in the notebook runtime.
        display(dot)
        return dot

Solution

[18]:
import collections

class Solution1(object):
    '''
    Assign every distinct subtree shape a compact integer id.

    Time: O(N) expected
    Space: O(N)
    '''
    def findDuplicateSubtrees(self, root):
        '''
        Return one root for every subtree shape that occurs more than once.

        A subtree is identified by the triple (value, left-subtree id,
        right-subtree id), and each new triple receives the next integer id.
        Two subtrees therefore share an id exactly when they have the same
        structure and the same node values.

        :type root: TreeNode
        :rtype: List[TreeNode]
        :returns: for each duplicated shape, its first root in post-order.
            Shapes are ordered by the post-order position of their first
            occurrence. An empty tree returns [].
        '''
        lookup = {}                            # (val, left_id, right_id) -> subtree id
        node_ids = {}                          # id(node) -> subtree id rooted at node
        trees = collections.defaultdict(list)  # subtree id -> roots, in post-order

        # Iterative post-order (left, right, node). A recursive walk exceeds
        # Python's default recursion limit on deep, skewed trees.
        stack = [(root, False)] if root else []
        while stack:
            node, children_done = stack.pop()
            if not children_done:
                # Come back to this node once both children have ids.
                stack.append((node, True))
                if node.right:
                    stack.append((node.right, False))
                if node.left:
                    stack.append((node.left, False))
                continue
            left_id = node_ids[id(node.left)] if node.left else None
            right_id = node_ids[id(node.right)] if node.right else None
            # Unseen signatures get the next consecutive id (0, 1, 2, ...).
            subtree_id = lookup.setdefault((node.val, left_id, right_id),
                                           len(lookup))
            node_ids[id(node)] = subtree_id
            trees[subtree_id].append(node)

        return [roots[0] for roots in trees.values() if len(roots) > 1]
[19]:
# Build the example tree from the problem statement:
#         1
#        / \
#       2   3
#      /   / \
#     4   2   4
#        /
#       4
s = Solution1()

root = TreeNode(1)
root.left, root.right = TreeNode(2), TreeNode(3)
root.left.left = TreeNode(4)
root.right.left, root.right.right = TreeNode(2), TreeNode(4)
root.right.left.left = TreeNode(4)

# Draw one representative root per duplicated subtree shape
# (expected: the subtrees [4] and [2,4]).
t = TreeTasks()
LTN = s.findDuplicateSubtrees(root)
for tree in LTN:
    dot = t.visualize_tree(tree)
../../_images/topics_tree_0652_find_duplicate_subtrees_[O(N),O(N),med]_6_0.svg
../../_images/topics_tree_0652_find_duplicate_subtrees_[O(N),O(N),med]_6_1.svg
[20]:
class Solution2(object):
    '''
    Serialize every subtree to a string and count repeated serializations.

    Time: O(N*H)
    Space: O(N*H)
    (H = tree height; each serialization is as long as its subtree.)
    '''
    def findDuplicateSubtrees(self, root):
        '''
        Return one root for every subtree shape that occurs more than once.

        Each subtree is serialized as "(" + left + str(val) + right + ")",
        where a missing child contributes "". The parentheses delimit every
        subtree, so equal strings mean equal structure and values. This
        assumes str(val) contains no parentheses, e.g. integer values.

        :type root: TreeNode
        :rtype: List[TreeNode]
        :returns: for each duplicated shape, the node at its second occurrence
            in post-order, in the order those occurrences are reached. An
            empty tree returns [].
        '''
        lookup = collections.defaultdict(int)  # serialization -> times seen so far
        result = []
        pending = {}  # id(node) -> serialization, held until its parent consumes it

        # Iterative post-order (left, right, node). A recursive walk exceeds
        # Python's default recursion limit on deep, skewed trees.
        stack = [(root, False)] if root else []
        while stack:
            node, children_done = stack.pop()
            if not children_done:
                # Come back to this node once both children are serialized.
                stack.append((node, True))
                if node.right:
                    stack.append((node.right, False))
                if node.left:
                    stack.append((node.left, False))
                continue
            left = pending.pop(id(node.left)) if node.left else ""
            right = pending.pop(id(node.right)) if node.right else ""
            s = "(" + left + str(node.val) + right + ")"
            # Report each shape exactly once: on its second occurrence.
            if lookup[s] == 1:
                result.append(node)
            lookup[s] += 1
            pending[id(node)] = s
        return result
[21]:
# Rebuild the example tree from the problem statement and draw every
# duplicate found by the serialization-based solution.
s = Solution2()

root = TreeNode(1)
root.left = TreeNode(2)
root.right = TreeNode(3)
root.left.left = TreeNode(4)
root.right.left = TreeNode(2)
root.right.right = TreeNode(4)
root.right.left.left = TreeNode(4)

t = TreeTasks()
LTN = s.findDuplicateSubtrees(root)
for tree in LTN:
    dot = t.visualize_tree(tree)
../../_images/topics_tree_0652_find_duplicate_subtrees_[O(N),O(N),med]_8_0.svg
../../_images/topics_tree_0652_find_duplicate_subtrees_[O(N),O(N),med]_8_1.svg